Author Topic Model

Implementation as described in http://mimno.infosci.cornell.edu/info6150/readings/398.pdf


In [1]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

In [69]:
np.ones([10])[[2,3,4], np.newaxis].repeat(5, axis=1)

_a = np.random.randint(0,10,size=(4,5)) * 1.0
print _a
_b = np.array([1,2,1,2,3])
print _b
_c = np.array(["V%s" % k for k in xrange(10)])
print _c
print (_a * _b)
_c[np.argsort(_a, axis=0)[::-1, :][:2, :]]


[[ 0.  7.  7.  7.  5.]
 [ 1.  7.  5.  1.  0.]
 [ 5.  5.  6.  9.  2.]
 [ 4.  6.  0.  7.  2.]]
[1 2 1 2 3]
['V0' 'V1' 'V2' 'V3' 'V4' 'V5' 'V6' 'V7' 'V8' 'V9']
[[  0.  14.   7.  14.  15.]
 [  1.  14.   5.   2.   0.]
 [  5.  10.   6.  18.   6.]
 [  4.  12.   0.  14.   6.]]
Out[69]:
array([['V2', 'V1', 'V0', 'V2', 'V0'],
       ['V3', 'V0', 'V2', 'V3', 'V3']], 
      dtype='|S2')

In [70]:
class AuthorTopicModel(object):
    """Implementation of an author topic model.
    Generates each document based on a topic and author pair,
    This is used to generate a word in the document.
    """
    
    def __init__(self, K, doc_word_matrix, doc_author_matrix, vocab, authornames, alpha=0.1, beta=0.5):
        """Constructor for the function
        K: number of topics
        doc_word_matrix: list of documents each represented as list of word ids
        doc_author_matrix: list of documents each represented as list of author ids
        vocab: dictionary of word ids mapped to word strings
        authornames: dictionary of author ids mapped to author names
        alpha: Author topic diritchelet parameter
        beta: Word topic diritchelet parameter        
        """
        self.K = K
        self.doc_word_matrix = doc_word_matrix
        self.doc_author_matrix = doc_author_matrix
        self.N = len(doc_word_matrix)
        self.vocab = vocab
        self.W = len(vocab)
        self.authornames= authornames
        self.A = len(authornames)
        self.alpha = alpha
        self.beta = beta
        
        self.W_T = np.zeros([self.W, self.K])
        self.A_T = np.zeros([self.A, self.K])
        
        self.T_marginal = np.zeros(self.K)
        self.A_marginal = np.zeros(self.A)
        
        self.T_assigned = []
        self.A_assigned = []
        self._populate_vars()
        
    def _populate_vars(self):
        """Populate the variables with the initial data
        """
        for di, doc in enumerate(self.doc_word_matrix):
            auth = self.doc_author_matrix[di]
            self.T_assigned.append([])
            self.A_assigned.append([])
            for wi, w in enumerate(doc):
                # Randomly assign a topic to the word
                z = np.random.choice(self.K)
                # Randomly assign a topic to a random author
                a = np.random.choice(self.A)
                # Update all the word, topic and author topic counts
                self.W_T[w,z] += 1
                self.A_T[a,z] += 1
                # Update marginals
                self.T_marginal[z] += 1
                self.A_marginal[a] += 1
                # Record the sampled topic and author assignments
                self.T_assigned[-1].append(z)
                self.A_assigned[-1].append(a)
    
    def gibbs_sampling(self):
        """Perform single gibbs sampling step
        """
        for di, doc in enumerate(self.doc_word_matrix):
            auth = self.doc_author_matrix[di]
            for wi, w in enumerate(doc):
                # Extract the previous assignment
                z = self.T_assigned[di][wi]
                a = self.A_assigned[di][wi]
                # Substract the previous assignments
                # Update all the word, topic and author topic counts
                self.W_T[w,z] -= 1
                self.A_T[a,z] -= 1
                # Update marginals
                self.T_marginal[z] -= 1
                self.A_marginal[a] -= 1
                
                # Find probability of the word w belonging to each topic
                phi = (self.W_T[w,:] + self.beta) / (self.T_marginal + self.W*self.beta)
                # Find probability of each author in auth belonging to each topic
                theta = (self.A_T[auth,:] + self.alpha) / (self.A_marginal[auth, np.newaxis] + self.W*self.alpha)
                # Joint probability of word and author for all topics
                pdf = theta*phi
                pdf = pdf / pdf.sum()
                # Index of authors and topics
                auth_t_pairs = [(i,j) for i in auth for j in xrange(self.K)]
                # Sample an author and topic pair for the word
                #print auth_t_pairs, p.flatten()
                idx = np.random.choice(range(len(auth_t_pairs)), p=pdf.flatten())
                a, z = auth_t_pairs[idx]
                # Update all the word, topic and author topic counts
                self.W_T[w,z] += 1
                self.A_T[a,z] += 1
                # Update marginals
                self.T_marginal[z] += 1
                self.A_marginal[a] += 1
                # Record the sampled topic and author assignments
                self.T_assigned[di][wi] = z
                self.A_assigned[di][wi] = a
    
    def perform_iterations(self, burnin=100, max_iters=10, print_every=5):
        """Perform max_iters of gibbs sampling steps
        """
        print "Performing %s gibbs sampling iterations burn in phase" % burnin
        for i in xrange(burnin):
            self.gibbs_sampling()
        print "Burn in complete"
        print "Topic proportions: %s" % (self.T_marginal * 1. / self.T_marginal.sum())
        print "Author proportions: %s" % (self.A_marginal * 1. / self.A_marginal.sum())
        print "W_T[w,z]:\n%s" % (self.W_T * 1./ self.W_T.sum())
        print "A_T[a,z]:\n%s" % (self.A_T * 1./ self.A_T.sum())
        print "Performing %s gibbs sampling iterations" % max_iters
        for i in xrange(max_iters):
            if i%print_every == 0:
                print "Iter %s:" % i
                self.gibbs_sampling()
                print "Topic proportions: %s" % (self.T_marginal * 1. / self.T_marginal.sum())
                print "Author proportions: %s" % (self.A_marginal * 1. / self.A_marginal.sum())
                print "W_T[w,z]:\n%s" % (self.W_T * 1./ self.W_T.sum())
                print "A_T[a,z]:\n%s" % (self.A_T * 1./ self.A_T.sum())
        print "Done"
    
    def show_topics(self, topn_w=3, topn_a=3):
        print "Top %s words per topic" % topn_w
        print self.vocab[np.argsort(self.W_T, axis=0)[::-1, :][:topn_w, :]]
        print "Top %s authors per topic" % topn_a
        print self.authornames[np.argsort(self.A_T, axis=0)[::-1, :][:topn_a, :]]

In [71]:
K = 3
doc_word_matrix = [[0,0,0,1,2,1],
                  [0,0,1,1,1,1,1],
                  [2,2,2,3,3,3],
                  [0,2,2,2,3,3,1],
                  [4,4,4,0,5,5,2],
                  [4,5,5,3,0,5,5,1]]
doc_author_matrix = [[0,1],
                     [1,2],
                     [0,1,2],
                     [2,3],
                     [4,5,3],
                     [4,5]]
vocab = np.array(["V%s" % k for k in xrange(6)])
authornames = np.array(["A%s" % k for k in xrange(6)])

K, doc_word_matrix, doc_author_matrix, vocab, authornames


Out[71]:
(3,
 [[0, 0, 0, 1, 2, 1],
  [0, 0, 1, 1, 1, 1, 1],
  [2, 2, 2, 3, 3, 3],
  [0, 2, 2, 2, 3, 3, 1],
  [4, 4, 4, 0, 5, 5, 2],
  [4, 5, 5, 3, 0, 5, 5, 1]],
 [[0, 1], [1, 2], [0, 1, 2], [2, 3], [4, 5, 3], [4, 5]],
 array(['V0', 'V1', 'V2', 'V3', 'V4', 'V5'], 
       dtype='|S2'),
 array(['A0', 'A1', 'A2', 'A3', 'A4', 'A5'], 
       dtype='|S2'))

In [72]:
atm = AuthorTopicModel(K, doc_word_matrix, doc_author_matrix, vocab, authornames)

In [73]:
atm.perform_iterations(max_iters=10)


Performing 100 gibbs sampling iterations burn in phase
Burn in complete
Topic proportions: [ 0.24390244  0.51219512  0.24390244]
Author proportions: [ 0.19512195  0.14634146  0.14634146  0.14634146  0.2195122   0.14634146]
W_T[w,z]:
[[ 0.02439024  0.14634146  0.02439024]
 [ 0.          0.12195122  0.09756098]
 [ 0.04878049  0.04878049  0.09756098]
 [ 0.          0.12195122  0.02439024]
 [ 0.09756098  0.          0.        ]
 [ 0.07317073  0.07317073  0.        ]]
A_T[a,z]:
[[ 0.          0.17073171  0.02439024]
 [ 0.02439024  0.12195122  0.        ]
 [ 0.          0.07317073  0.07317073]
 [ 0.          0.          0.14634146]
 [ 0.2195122   0.          0.        ]
 [ 0.          0.14634146  0.        ]]
Performing 10 gibbs sampling iterations
Iter 0:
Topic proportions: [ 0.24390244  0.48780488  0.26829268]
Author proportions: [ 0.12195122  0.19512195  0.24390244  0.12195122  0.24390244  0.07317073]
W_T[w,z]:
[[ 0.          0.12195122  0.07317073]
 [ 0.          0.14634146  0.07317073]
 [ 0.          0.12195122  0.07317073]
 [ 0.          0.09756098  0.04878049]
 [ 0.09756098  0.          0.        ]
 [ 0.14634146  0.          0.        ]]
A_T[a,z]:
[[ 0.          0.09756098  0.02439024]
 [ 0.          0.19512195  0.        ]
 [ 0.          0.12195122  0.12195122]
 [ 0.          0.          0.12195122]
 [ 0.24390244  0.          0.        ]
 [ 0.          0.07317073  0.        ]]
Iter 5:
Topic proportions: [ 0.17073171  0.58536585  0.24390244]
Author proportions: [ 0.09756098  0.17073171  0.29268293  0.12195122  0.17073171  0.14634146]
W_T[w,z]:
[[ 0.          0.09756098  0.09756098]
 [ 0.          0.2195122   0.        ]
 [ 0.          0.17073171  0.02439024]
 [ 0.          0.04878049  0.09756098]
 [ 0.04878049  0.04878049  0.        ]
 [ 0.12195122  0.          0.02439024]]
A_T[a,z]:
[[ 0.          0.09756098  0.        ]
 [ 0.          0.17073171  0.        ]
 [ 0.          0.2195122   0.07317073]
 [ 0.          0.          0.12195122]
 [ 0.17073171  0.          0.        ]
 [ 0.          0.09756098  0.04878049]]
Done

In [74]:
atm.show_topics()


Top 3 words per topic
[['V5' 'V1' 'V3']
 ['V4' 'V2' 'V0']
 ['V3' 'V0' 'V5']]
Top 3 authors per topic
[['A4' 'A2' 'A3']
 ['A5' 'A1' 'A2']
 ['A3' 'A5' 'A5']]

In [ ]: